#!/usr/bin/env python3
import json
from pathlib import Path
import sys; sys.path.append(str(Path(__file__).parent.parent.resolve()))

from manual_sweep import defaults

def create_sweep(fname, envs, env2experts_list, env2ase_sigma, algorithms, learner_pis, seeds, pggae=False):
    """Write a JSONL sweep file containing one run configuration per line.

    For every seed and every environment, one ``lops-aps-ase`` configuration
    is emitted for each combination of expert entry, (learner_pi, algorithm)
    pair and ``ase_sigma`` value. When ``pggae`` is true, one extra
    ``pg-gae`` configuration is appended after those, once per (seed, env).

    Every configuration starts from the module-level ``defaults`` mapping and
    the swept keys are layered on top of it, so a key that also appears in
    ``defaults`` cannot overwrite a swept value.

    Args:
        fname: Path of the output file; it is opened in write mode and
            therefore overwritten if it already exists.
        envs: Environment names; each is emitted as ``dmc:<env>-v1``.
        env2experts_list: Maps an environment name to the values swept for
            the ``load_expert_step`` key.
        env2ase_sigma: Maps an environment name to the values swept for the
            ``ase_sigma`` key.
        algorithms: Algorithm names, paired position-wise with
            ``learner_pis``. Every entry must be ``'lops-aps-ase'``.
        learner_pis: Values for the ``use_riro_for_learner_pi`` key, paired
            position-wise with ``algorithms``.
        seeds: Seed values to sweep over.
        pggae: Whether to also emit the ``pg-gae`` configuration for each
            (seed, env) pair.

    Raises:
        ValueError: If ``learner_pis`` and ``algorithms`` differ in length,
            or if any algorithm is not ``'lops-aps-ase'``.
        KeyError: If an environment is missing from ``env2experts_list`` or
            ``env2ase_sigma``.
    """
    learner_pis = list(learner_pis)
    algorithms = list(algorithms)

    # zip() would silently drop the unmatched tail and shrink the sweep, so
    # reject mismatched inputs explicitly.
    if len(learner_pis) != len(algorithms):
        raise ValueError(
            f'learner_pis ({len(learner_pis)}) and algorithms ({len(algorithms)}) '
            'must have the same length'
        )

    # Validate once up front with a real exception; an assert would be
    # stripped when Python runs with -O.
    unsupported = [name for name in algorithms if name != 'lops-aps-ase']
    if unsupported:
        raise ValueError(f"only 'lops-aps-ase' is supported, got {unsupported}")

    pi_algo_pairs = list(zip(learner_pis, algorithms))

    lines = []
    for seed in seeds:
        for env in envs:
            env_name = f'dmc:{env}-v1'
            print('env', env, 'env_name', env_name)

            # lops-aps-ase configurations
            for experts in env2experts_list[env]:
                for learner_pi, algorithm in pi_algo_pairs:
                    for ase_sigma in env2ase_sigma[env]:
                        lines.append(
                            {
                                # Defaults first: the swept keys below must win.
                                **defaults,
                                "env_name": env_name,
                                "load_expert_step": experts,
                                "algorithm": algorithm,
                                "use_riro_for_learner_pi": learner_pi,
                                "ase_sigma": ase_sigma,
                                "seed": seed,
                            }
                        )

            if pggae:
                # pg-gae configuration, one per (seed, env)
                lines.append(
                    {
                        **defaults,
                        "env_name": env_name,
                        "load_expert_step": [0],
                        "algorithm": "pg-gae",
                        "use_riro_for_learner_pi": "none",
                        "ase_sigma": 0,
                        "seed": seed,
                    }
                )

    json_text = [json.dumps(line, sort_keys=True) for line in lines]
    print(f'{len(json_text)} lines to {fname}')
    with open(fname, 'w', encoding='utf-8') as f:
        f.write('\n'.join(json_text))


if __name__ == '__main__':
    # --- Sweep grid -------------------------------------------------------
    envs = ['Cheetah-run']

    # Per-environment values emitted under the "load_expert_step" key.
    env2experts_list = {
        'Cheetah-run': [[100], [100, 70], [100, 70, 40], [100, 70, 40, 20]],
    }

    # Per-environment ase_sigma candidates: 0.5, 1.0, ..., 10.0 (20 values).
    env2ase_sigma = {
        'Cheetah-run': [0.5 * step for step in range(1, 21)],
    }

    seeds = list(range(3))

    # create_sweep pairs these two lists position-wise.
    learner_pis = ['rollin', 'all']
    algorithms = ['lops-aps-ase'] * len(learner_pis)

    # --- Output -----------------------------------------------------------
    # The sweep file takes the invoked script's stem plus ".jsonl"; the
    # resulting path is relative, so it resolves against the working directory.
    script_path = sys.argv[0]
    fname = f'{Path(script_path).stem}.jsonl'
    create_sweep(fname, envs, env2experts_list, env2ase_sigma, algorithms, learner_pis, seeds)
